fix(mla): widen page index to int64_t to avoid 32-bit overflow#3136
Tracin wants to merge 2 commits into flashinfer-ai:main
Conversation
…address computation

In the MLA decode/prefill KV load path, `indices[q] * ckv_stride_page` was computed in 32-bit because `IdType` is `int32_t` and `*_stride_page` is `uint32_t`; the product wraps modulo 2^32 before any widening to `int64_t` (Hopper) or pointer arithmetic (FA2). For large page pools (e.g. page_idx ~1M with page_size=32, kv_lora_rank=512, stride=16384) the true product exceeds 2^32 and the kernel reads the wrong page, producing all-zero outputs.

Cast the selected page index to `int64_t` at all three sites (mla.cuh NUM_MMA_KV==1 and !=1 branches, and mla_hopper.cuh prefetch_offset) so the multiply executes in 64-bit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
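The wrap described in the commit message can be reproduced with plain modular arithmetic. A minimal Python sketch (the parameter values come from the message above; the helper names are illustrative, and `& 0xFFFFFFFF` models the unsigned 32-bit multiply that C++'s usual arithmetic conversions produce for `int32_t * uint32_t`):

```python
def offset_32bit(page_idx: int, stride_page: int) -> int:
    # int32_t * uint32_t promotes both operands to unsigned 32-bit in C++,
    # so the product wraps modulo 2^32 before any widening to int64_t.
    return (page_idx * stride_page) & 0xFFFFFFFF

def offset_64bit(page_idx: int, stride_page: int) -> int:
    # With the fix, the page index is cast to int64_t first, so the
    # multiply executes in 64-bit and cannot wrap at this step.
    return page_idx * stride_page

stride_page = 32 * 512   # page_size * kv_lora_rank = 16384 elements
page_idx = 1_000_000     # ~1M pages, as in the report

true_offset = offset_64bit(page_idx, stride_page)  # 16_384_000_000
wrapped = offset_32bit(page_idx, stride_page)      # 3_499_098_112

assert true_offset >= 2**32        # exceeds uint32 range
assert wrapped != true_offset      # the kernel would read the wrong page
```

The same arithmetic shows why the bug is silent: the wrapped offset still lands inside the allocation for sufficiently large caches, so there is no fault, just wrong data.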
📝 Walkthrough

Widened KV page-index arithmetic to 64-bit in two CUDA attention sources to prevent uint32 overflow during stride multiplications; added a regression test that exercises page-index overflow cases for MLA decode kernels.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: 4 passed, 1 failed (1 warning)
Code Review
This pull request addresses potential 32-bit integer overflows in KV cache offset calculations within mla.cuh and mla_hopper.cuh by casting page indices to int64_t. Feedback suggests that the entire offset calculation should be promoted to 64-bit to prevent overflows in subsequent additions and to improve future-proofing.
Nice fix, the cast is in the right place and all three call sites are covered. One suggestion: add a minimal regression test that forces `indices[q] * ckv_stride_page` to overflow uint32. You don't need a huge KV cache for that, a sparse set of large page indices is enough.
@qsang-nv Thanks for the review! However I do not get how …
@Tracin You are right, we do need a real address; however, it can be smaller than the script you provided in the issue. The overflow triggers as soon as `indices[q] * ckv_stride_page >= 2^32`.

Overflow threshold: with page_size=32 and head_dim_ckv=512, `ckv_stride_page` is 16384 elements, so the first overflowing page index is 2^32 / 16384 = 262144.

Allocation needed: tensors are contiguous, so to make page index 262144 addressable the cache must hold at least 262145 pages (roughly 9.66 GiB in fp16 for the big cache).

Fits on most modern GPUs. Gate with a free-VRAM check so the test self-skips on smaller devices.
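The threshold and allocation figures above can be written out as a quick sanity check. This is a sketch: head_dim_kpe=64 and the 2-byte fp16 element size are assumptions, everything else comes from the thread.

```python
# Overflow threshold for the cache shape discussed above.
page_size, head_dim_ckv = 32, 512
head_dim_kpe = 64                            # assumed MLA rope head dim
ckv_stride_page = page_size * head_dim_ckv   # 16384 elements per page
threshold = 2**32 // ckv_stride_page         # first overflowing page index

# Contiguous tensors: to make index `threshold` addressable, the cache
# needs at least threshold + 1 pages. With the 26 test pages placed at
# the overflow boundary (per the commit message) and 2 bytes per element:
total_pages = threshold + 26
bytes_per_page = page_size * (head_dim_ckv + head_dim_kpe) * 2
total_bytes = total_pages * bytes_per_page

print(threshold)                     # 262144
print(round(total_bytes / 1e9, 2))   # 9.66
```

The byte count lands close to the ~9.66 figure quoted in the commit message for the big cache under these assumptions.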
Exercises the int64 widening added in a716f93 by running a 26-page MLA decode with page indices starting at 262144 — the smallest index that makes `indices[q] * ckv_stride_page` overflow uint32 for a contiguous [*, 32, 512] cache (stride(0) = 16384). Compared against a reference run with the same data and stride but non-overflowing indices; pre-fix, the big-index run silently reads the wrong page and produces garbage output with no crash. Self-skips when free VRAM is below 12 GiB (the big cache alone is ~9.66 GiB). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
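A condensed, CPU-only sketch of the index layout the commit describes (pure Python; `NUM_PAGES` and `OVERFLOW_START` mirror the values in the commit message, the rest is illustrative):

```python
NUM_PAGES = 26                   # pages actually holding real KV data
OVERFLOW_START = 262144          # smallest uint32-overflowing page index
ckv_stride_page = 32 * 512       # elements per page for a [*, 32, 512] cache

# Big-index run vs. reference run over the same underlying data.
big_indices = list(range(OVERFLOW_START, OVERFLOW_START + NUM_PAGES))
ref_indices = list(range(NUM_PAGES))

for big, ref in zip(big_indices, ref_indices):
    correct = big * ckv_stride_page
    wrapped = correct & 0xFFFFFFFF   # what 32-bit math produced pre-fix
    # Every big index wraps (index 262144 wraps all the way to offset 0,
    # consistent with the silent all-zero outputs in the issue), while
    # the reference indices never leave uint32 range.
    assert wrapped != correct
    assert ref * ckv_stride_page < 2**32
```

Post-fix, both runs must produce identical outputs; pre-fix, only the big-index run reads wrapped offsets, so comparing the two catches the regression without any crash being required.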
@qsang-nv I see. Test is added! |
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/attention/test_mla_decode_kernel.py (1)
571-578: Avoid zero-filling the entire ~9.6 GiB cache. Only the wrapped low pages and the overflow suffix are observed by this regression. Using `empty()` and zeroing the wrapped pages keeps the pre-fix failure deterministic without the full-cache memset cost.

Proposed refactor
```diff
-    ckv_big = torch.zeros(
+    ckv_big = torch.empty(
         total_num_pages, page_size, head_dim_ckv, device=device, dtype=dtype
     )
-    kpe_big = torch.zeros(
+    kpe_big = torch.empty(
         total_num_pages, page_size, head_dim_kpe, device=device, dtype=dtype
     )
+    ckv_big[:NUM_PAGES].zero_()
+    kpe_big[:NUM_PAGES].zero_()
     ckv_big[OVERFLOW_START:] = real_ckv
     kpe_big[OVERFLOW_START:] = real_kpe
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_mla_decode_kernel.py` around lines 571 - 578, The test currently allocates ckv_big and kpe_big with torch.zeros(...) which zero-fills the entire ~9.6 GiB cache; instead allocate with torch.empty(total_num_pages, page_size, head_dim_ckv/ head_dim_kpe, device=device, dtype=dtype) for ckv_big/kpe_big and only zero the observed regions: the wrapped low pages slice and the overflow suffix (use OVERFLOW_START and the indices for the wrapped pages derived from total_num_pages/page_size and wrap logic) then assign real_ckv/real_kpe into ckv_big[OVERFLOW_START:] and kpe_big[OVERFLOW_START:]; this preserves deterministic pre-fix behavior while avoiding a full-cache memset.
📒 Files selected for processing (1)
tests/attention/test_mla_decode_kernel.py
```python
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
def test_mla_page_index_uint32_overflow_regression(backend):
    # Regression for the int64 widening in mla.cuh / mla_hopper.cuh
    # (`indices[q] * ckv_stride_page`). For a contiguous
    # [num_pages, page_size, head_dim_ckv] cache with page_size=32 and
    # head_dim_ckv=512, ckv_stride_page = 16384 elements. Any page index
    # >= 2^32 / 16384 = 262144 makes the multiplication overflow uint32 and
    # — pre-fix — silently wraps to the wrong page (no crash, wrong output).
    device = torch.device("cuda:0")
    if backend == "fa3" and not is_sm90a_supported(device):
        pytest.skip("fa3 backend requires SM90a")
```
Add the existing architecture-error skip wrapper to this backend-parametrized test.
fa3 is gated, but fa2 or backend dispatch can still raise an unsupported-architecture error on some runners. Match the existing MLA decode test by wrapping this regression with @skip_on_gpu_arch_error.
Proposed fix

```diff
+@skip_on_gpu_arch_error
 @pytest.mark.parametrize("backend", ["fa2", "fa3"])
 def test_mla_page_index_uint32_overflow_regression(backend):
```

As per coding guidelines, tests/**/*.py: "Skip test execution on unsupported GPU architectures using flashinfer.utils check functions (is_sm90a_supported(), is_sm100a_supported(), etc.) or API methods like api_name.is_compute_capability_supported(cc)".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/attention/test_mla_decode_kernel.py` around lines 512 - 522, Add the
existing architecture-error skip wrapper to the parametrized test by decorating
test_mla_page_index_uint32_overflow_regression with `@skip_on_gpu_arch_error`
(imported from flashinfer.utils or the test utilities), so that
GPU-architecture-related exceptions raised for either backend are handled; keep
the existing `@pytest.mark.parametrize`("backend", ["fa2", "fa3"]) and the in-body
SM90a guard (is_sm90a_supported) for fa3, but place the `@skip_on_gpu_arch_error`
decorator directly above the test function definition to match other MLA decode
tests.
/bot run
📌 Description
In the MLA decode/prefill KV load path,
indices[q] * ckv_stride_pagewas computed in 32-bit becauseIdTypeisint32_tand*_stride_pageisuint32_t; the product wraps modulo 2^32 before any widening toint64_t(Hopper) or pointer arithmetic (FA2). For large page pools (e.g. page_idx ~1M with page_size=32, kv_lora_rank=512, stride=16384) the true product exceeds 2^32 and the kernel reads the wrong page, producing all-zero outputs. Cast the selected page index toint64_tat all three sites (mla.cuh NUM_MMA_KV==1 and !=1 branches, and mla_hopper.cuh prefetch_offset) so the multiply executes in 64-bit.🔍 Related Issues
#3130
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks against all changed files with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes